import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf


def display(display_list, size):
    """Draw a row of images side by side in a new matplotlib figure.

    Each item is converted with ``tf.keras.preprocessing.image.array_to_img``
    and drawn in its own subplot, left to right, with axes hidden. The figure
    is neither shown nor returned; call ``plt.show()`` (or rely on an inline
    notebook backend) to render it.

    Args:
        display_list: Sequence of images accepted by ``array_to_img``
            (presumably HxWxC arrays/tensors — confirm against callers).
        size: Matplotlib ``figsize`` for the whole figure, i.e.
            ``(width, height)`` in inches.

    Raises:
        ValueError: If ``display_list`` is empty.
    """
    # Use len() rather than truthiness: display_list may be a NumPy array,
    # whose truth value is ambiguous. An empty input would otherwise give
    # dpi=0 and a zero-column subplot grid.
    length = len(display_list)
    if length == 0:
        raise ValueError("display_list must contain at least one image")
    # dpi grows with the number of panels, presumably so each panel keeps
    # roughly the same pixel resolution as the row gets wider.
    plt.figure(figsize=size, dpi=100 * length)
    for position, image in enumerate(display_list, start=1):
        plt.subplot(1, length, position)
        plt.imshow(tf.keras.preprocessing.image.array_to_img(image))
        plt.axis('off')


def create_mask(pred_mask):
    """Collapse per-class scores into an integer class-index mask.

    Takes the index of the maximum value along the last axis, then re-adds
    that axis with length 1, so the result has the same rank as the input.

    Args:
        pred_mask: Array-like of scores whose last axis indexes classes
            (assumed from usage — confirm against the model output).

    Returns:
        A NumPy integer array of argmax indices with a trailing axis of size 1.
    """
    class_indices = np.argmax(pred_mask, axis=-1)
    return np.expand_dims(class_indices, axis=-1)
